{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download model\n",
    "\n",
    "Download whichever pretrained model you prefer; the cells below assume the `bert-bahasa-9-july-2019` checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://huseinhouse-storage.s3-ap-southeast-1.amazonaws.com/bert-bahasa/bert-bahasa-9-july-2019.tar.gz\n",
    "# !tar -zxf bert-bahasa-9-july-2019.tar.gz\n",
    "# !pip3 install bert-tensorflow --user"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "config.json\t\t\t\tmodel.ckpt-1000000.meta\r\n",
      "model.ckpt-1000000.data-00000-of-00001\tsp10m.cased.v4.model\r\n",
      "model.ckpt-1000000.index\t\tsp10m.cased.v4.vocab\r\n"
     ]
    }
   ],
   "source": [
    "!ls bert-bahasa-9-july-2019"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import re\n",
    "import sentencepiece as spm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from prepro_utils import preprocess_text, encode_ids, encode_pieces\n",
    "\n",
    "# Load the SentencePiece model that ships with the downloaded checkpoint.\n",
    "sp_model = spm.SentencePieceProcessor()\n",
    "sp_model.Load('bert-bahasa-9-july-2019/sp10m.cased.v4.model')\n",
    "\n",
    "# Build a {piece: second-column string} mapping from the tab-separated vocab file\n",
    "# (the second column is presumably the SentencePiece score -- TODO confirm).\n",
    "# The pieces are non-ASCII, so decode explicitly as UTF-8 instead of relying on the\n",
    "# platform default, and iterate line by line so the last entry is kept even when the\n",
    "# file does not end with a newline.\n",
    "with open('bert-bahasa-9-july-2019/sp10m.cased.v4.vocab', encoding = 'utf-8') as fopen:\n",
    "    v = {}\n",
    "    for line in fopen:\n",
    "        line = line.rstrip('\\n')\n",
    "        if not line:\n",
    "            continue\n",
    "        fields = line.split('\\t')\n",
    "        v[fields[0]] = fields[1]\n",
    "\n",
    "class Tokenizer:\n",
    "    \"\"\"Thin tokenizer facade over the module-level SentencePiece model `sp_model`.\"\"\"\n",
    "    \n",
    "    def __init__(self, v):\n",
    "        \"\"\"Keep the vocabulary mapping `v` on the instance as `self.vocab`.\"\"\"\n",
    "        self.vocab = v\n",
    "    \n",
    "    def tokenize(self, string):\n",
    "        \"\"\"Return the pieces `encode_pieces` produces for `string` (sampling disabled).\"\"\"\n",
    "        pieces = encode_pieces(sp_model, string, return_unicode=False, sample=False)\n",
    "        return pieces\n",
    "    \n",
    "    def convert_tokens_to_ids(self, tokens):\n",
    "        \"\"\"Look up every piece in `tokens` with `sp_model.PieceToId` and return the ids as a list.\"\"\"\n",
    "        ids = []\n",
    "        for piece in tokens:\n",
    "            ids.append(sp_model.PieceToId(piece))\n",
    "        return ids\n",
    "    \n",
    "    def convert_ids_to_tokens(self, ids):\n",
    "        \"\"\"Look up every id in `ids` with `sp_model.IdToPiece` and return the pieces as a list.\"\"\"\n",
    "        pieces = []\n",
    "        for piece_id in ids:\n",
    "            pieces.append(sp_model.IdToPiece(piece_id))\n",
    "        return pieces\n",
    "    \n",
    "tokenizer = Tokenizer(v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import bert\n",
    "from bert import run_classifier\n",
    "from bert import optimization\n",
    "from bert import tokenization\n",
    "from bert import modeling\n",
    "import numpy as np\n",
    "import json\n",
    "import tensorflow as tf\n",
    "import itertools\n",
    "from unidecode import unidecode\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint prefix restored into the 'bert' variables, and the JSON config used to build BertConfig.\n",
    "# Both live in the directory listed by the `!ls` cell above.\n",
    "BERT_INIT_CHKPNT = 'bert-bahasa-9-july-2019/model.ckpt-1000000'\n",
    "BERT_CONFIG = 'bert-bahasa-9-july-2019/config.json'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://raw.githubusercontent.com/huseinzol05/Malaya-Dataset/master/subjectivity/subjectivity-negative-bm.txt\n",
    "# !wget https://raw.githubusercontent.com/huseinzol05/Malaya-Dataset/master/subjectivity/subjectivity-positive-bm.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each non-blank line of a file is one sample. Blank lines (including the empty string that\n",
    "# split('\\n') yields after a trailing newline) are dropped so no empty text gets a label.\n",
    "# The files are decoded explicitly as UTF-8 rather than with the platform default.\n",
    "with open('subjectivity-negative-bm.txt', 'r', encoding = 'utf-8') as fopen:\n",
    "    negative_texts = [line for line in fopen.read().split('\\n') if line.strip()]\n",
    "\n",
    "with open('subjectivity-positive-bm.txt', 'r', encoding = 'utf-8') as fopen:\n",
    "    positive_texts = [line for line in fopen.read().split('\\n') if line.strip()]\n",
    "\n",
    "# Label 0 = lines from the 'negative' file, label 1 = lines from the 'positive' file.\n",
    "texts = negative_texts + positive_texts\n",
    "labels = [0] * len(negative_texts) + [1] * len(positive_texts)\n",
    "\n",
    "assert len(labels) == len(texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['▁yang',\n",
       " '▁muncul',\n",
       " '▁dari',\n",
       " '▁jiwa',\n",
       " '▁manusia',\n",
       " '▁dan',\n",
       " '▁menunjukkan',\n",
       " '▁ciri',\n",
       " '-',\n",
       " 'ciri',\n",
       " '▁',\n",
       " 'abstrak',\n",
       " '▁express',\n",
       " 'ion',\n",
       " 'ism',\n",
       " '▁',\n",
       " 'abstrak',\n",
       " '▁dan',\n",
       " '▁penyingkiran',\n",
       " '▁',\n",
       " 'grafi',\n",
       " 'ti',\n",
       " '▁kon',\n",
       " 's',\n",
       " 'truk',\n",
       " 'tiv',\n",
       " 'isme',\n",
       " '▁russian',\n",
       " '▁telah',\n",
       " '▁menguatkan',\n",
       " '▁tempatnya',\n",
       " '▁dalam',\n",
       " '▁sejarah',\n",
       " '▁seni',\n",
       " '▁moden',\n",
       " '▁ketika',\n",
       " '▁dicipta',\n",
       " '▁oleh',\n",
       " '▁artis',\n",
       " '▁yang',\n",
       " '▁tidak',\n",
       " '▁sedar',\n",
       " 'kan',\n",
       " '▁diri',\n",
       " '▁dengan',\n",
       " '▁pencapaian',\n",
       " '▁kesenian',\n",
       " '▁mereka']"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fixed length (including the <cls>/<sep> markers) that every example is truncated or padded to\n",
    "# when the input features are built.\n",
    "MAX_SEQ_LENGTH = 100\n",
    "# Sanity check: display the pieces produced for one sample sentence.\n",
    "tokenizer.tokenize(texts[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 9962/9962 [00:01<00:00, 7403.98it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Three parallel lists with one MAX_SEQ_LENGTH-long entry per text.\n",
    "input_ids, input_masks, segment_ids = [], [], []\n",
    "\n",
    "for text in tqdm(texts):\n",
    "    # Keep at most MAX_SEQ_LENGTH - 2 pieces so the <cls> and <sep> markers always fit.\n",
    "    pieces = tokenizer.tokenize(text)[:(MAX_SEQ_LENGTH - 2)]\n",
    "    tokens = [\"<cls>\"] + pieces + [\"<sep>\"]\n",
    "    \n",
    "    n_real = len(tokens)\n",
    "    n_pad = MAX_SEQ_LENGTH - n_real\n",
    "    \n",
    "    # Ids are right-padded with 0; the mask is 1 on real positions and 0 on padding;\n",
    "    # every position gets segment id 0.\n",
    "    input_ids.append(tokenizer.convert_tokens_to_ids(tokens) + [0] * n_pad)\n",
    "    input_masks.append([1] * n_real + [0] * n_pad)\n",
    "    segment_ids.append([0] * n_real + [0] * n_pad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the BertConfig from the JSON file at BERT_CONFIG; the Model class reads this global.\n",
    "bert_config = modeling.BertConfig.from_json_file(BERT_CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Schedule settings; num_train_steps and num_warmup_steps are read as globals by\n",
    "# optimization.create_optimizer inside Model.\n",
    "epoch = 10\n",
    "batch_size = 60\n",
    "warmup_proportion = 0.1\n",
    "# NOTE(review): the step count is derived from len(texts) (the full dataset), whereas the training\n",
    "# loop iterates over the 80% train split and ends through early stopping rather than after `epoch`\n",
    "# epochs, so this only approximates the number of optimizer steps actually run -- confirm intended.\n",
    "num_train_steps = int(len(texts) / batch_size * epoch)\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    \"\"\"BERT encoder followed by a dense layer on the pooled output, trained with cross-entropy.\n",
    "\n",
    "    Reads the globals `bert_config`, `num_train_steps` and `num_warmup_steps`.\n",
    "    Exposes the placeholders `X`, `segment_ids`, `input_masks`, `Y` and the tensors/ops\n",
    "    `logits`, `cost`, `optimizer`, `accuracy`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dimension_output,\n",
    "        learning_rate = 2e-5,\n",
    "    ):\n",
    "        \"\"\"Build the graph.\n",
    "\n",
    "        dimension_output: number of units of the final dense layer (one logit per class).\n",
    "        learning_rate: learning rate handed to optimization.create_optimizer.\n",
    "        \"\"\"\n",
    "        # Integer inputs with two unspecified dimensions, plus one integer label per example.\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.segment_ids = tf.placeholder(tf.int32, [None, None])\n",
    "        self.input_masks = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None])\n",
    "\n",
    "        # NOTE(review): is_training=False is used even though this graph is also the one being\n",
    "        # optimised -- confirm this is intended.\n",
    "        encoder = modeling.BertModel(\n",
    "            config = bert_config,\n",
    "            is_training = False,\n",
    "            input_ids = self.X,\n",
    "            input_mask = self.input_masks,\n",
    "            token_type_ids = self.segment_ids,\n",
    "            use_one_hot_embeddings = False,\n",
    "        )\n",
    "\n",
    "        pooled = encoder.get_pooled_output()\n",
    "        self.logits = tf.layers.dense(pooled, dimension_output)\n",
    "\n",
    "        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "            logits = self.logits, labels = self.Y\n",
    "        )\n",
    "        self.cost = tf.reduce_mean(per_example_loss)\n",
    "\n",
    "        # The trailing False is the fifth positional argument of create_optimizer\n",
    "        # (presumably `use_tpu` -- TODO confirm against the installed bert package).\n",
    "        self.optimizer = optimization.create_optimizer(\n",
    "            self.cost, learning_rate, num_train_steps, num_warmup_steps, False\n",
    "        )\n",
    "\n",
    "        # Fraction of examples whose argmax class equals the label.\n",
    "        predicted = tf.argmax(self.logits, 1, output_type = tf.int32)\n",
    "        hits = tf.equal(predicted, self.Y)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(hits, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n",
      "\n",
      "WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /home/jupyter/.local/lib/python3.6/site-packages/bert/modeling.py:671: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dense instead.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/learning_rate_decay_v2.py:321: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Deprecated in favor of operator or tf.math.divide.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use standard file APIs to check for files with this prefix.\n",
      "INFO:tensorflow:Restoring parameters from bert-bahasa-9-july-2019/model.ckpt-1000000\n"
     ]
    }
   ],
   "source": [
    "# Two output classes (labels 0 and 1) and the learning rate passed to Model.\n",
    "dimension_output = 2\n",
    "learning_rate = 2e-5\n",
    "\n",
    "# Start from an empty default graph so re-running this cell does not stack a second model on top.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(\n",
    "    dimension_output,\n",
    "    learning_rate\n",
    ")\n",
    "\n",
    "# Order matters: initialise every variable first, then overwrite only the trainable variables\n",
    "# under the 'bert' scope with the values stored in the pretrained checkpoint.\n",
    "sess.run(tf.global_variables_initializer())\n",
    "var_lists = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope = 'bert')\n",
    "saver = tf.train.Saver(var_list = var_lists)\n",
    "saver.restore(sess, BERT_INIT_CHKPNT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jupyter/.local/lib/python3.6/site-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "# sklearn.cross_validation was deprecated in scikit-learn 0.18 and removed in 0.20;\n",
    "# sklearn.model_selection is the module that replaced it.\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Split the four parallel lists in a single call so ids, masks, segments and labels stay aligned;\n",
    "# 20% of the examples are held out.\n",
    "train_input_ids, test_input_ids, train_input_masks, test_input_masks, train_segment_ids, test_segment_ids, train_Y, test_Y = train_test_split(\n",
    "    input_ids, input_masks, segment_ids, labels, test_size = 0.2\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 133/133 [00:52<00:00,  2.74it/s, accuracy=0.918, cost=0.158]\n",
      "test minibatch loop: 100%|██████████| 34/34 [00:04<00:00,  7.81it/s, accuracy=1, cost=0.0594]   \n",
      "train minibatch loop:   0%|          | 0/133 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.000000, current acc: 0.910186\n",
      "time taken: 56.43599534034729\n",
      "epoch: 0, training loss: 0.357670, training acc: 0.836630, valid loss: 0.275480, valid acc: 0.910186\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 133/133 [00:50<00:00,  2.74it/s, accuracy=0.959, cost=0.081] \n",
      "test minibatch loop: 100%|██████████| 34/34 [00:04<00:00,  8.31it/s, accuracy=1, cost=0.0618]   \n",
      "train minibatch loop:   0%|          | 0/133 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, pass acc: 0.910186, current acc: 0.926242\n",
      "time taken: 54.693442583084106\n",
      "epoch: 1, training loss: 0.197910, training acc: 0.927036, valid loss: 0.275675, valid acc: 0.926242\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 133/133 [00:50<00:00,  2.72it/s, accuracy=0.98, cost=0.0683] \n",
      "test minibatch loop: 100%|██████████| 34/34 [00:04<00:00,  8.27it/s, accuracy=0.846, cost=0.824]\n",
      "train minibatch loop:   0%|          | 0/133 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 54.85988974571228\n",
      "epoch: 2, training loss: 0.087594, training acc: 0.972616, valid loss: 0.569989, valid acc: 0.875950\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 133/133 [00:50<00:00,  2.72it/s, accuracy=1, cost=0.00078]   \n",
      "test minibatch loop: 100%|██████████| 34/34 [00:04<00:00,  8.25it/s, accuracy=0.846, cost=0.53] \n",
      "train minibatch loop:   0%|          | 0/133 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.07711744308472\n",
      "epoch: 3, training loss: 0.035270, training acc: 0.990589, valid loss: 0.486300, valid acc: 0.915087\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 133/133 [00:50<00:00,  2.72it/s, accuracy=1, cost=0.000849]  \n",
      "test minibatch loop: 100%|██████████| 34/34 [00:04<00:00,  8.24it/s, accuracy=1, cost=0.00546]  \n",
      "train minibatch loop:   0%|          | 0/133 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 4, pass acc: 0.926242, current acc: 0.934772\n",
      "time taken: 54.99451947212219\n",
      "epoch: 4, training loss: 0.013198, training acc: 0.997992, valid loss: 0.602157, valid acc: 0.934772\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 133/133 [00:50<00:00,  2.72it/s, accuracy=1, cost=1.38e-5]   \n",
      "test minibatch loop: 100%|██████████| 34/34 [00:04<00:00,  8.23it/s, accuracy=0.923, cost=0.116]\n",
      "train minibatch loop:   0%|          | 0/133 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.04400396347046\n",
      "epoch: 5, training loss: 0.003051, training acc: 1.000376, valid loss: 0.753478, valid acc: 0.923926\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 133/133 [00:50<00:00,  2.72it/s, accuracy=1, cost=7.52e-6] \n",
      "test minibatch loop: 100%|██████████| 34/34 [00:04<00:00,  8.25it/s, accuracy=0.923, cost=0.109]\n",
      "train minibatch loop:   0%|          | 0/133 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 55.0029354095459\n",
      "epoch: 6, training loss: 0.000175, training acc: 1.001380, valid loss: 0.736715, valid acc: 0.929445\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 133/133 [00:50<00:00,  2.72it/s, accuracy=1, cost=4.75e-6]\n",
      "test minibatch loop: 100%|██████████| 34/34 [00:04<00:00,  8.26it/s, accuracy=0.923, cost=0.266]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 54.99624180793762\n",
      "epoch: 7, training loss: 0.000011, training acc: 1.001380, valid loss: 0.754906, valid acc: 0.926435\n",
      "\n",
      "break epoch:8\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "# EARLY_STOPPING: number of consecutive epochs without a new best held-out accuracy before stopping.\n",
    "# CURRENT_CHECKPOINT: epochs since the last improvement. CURRENT_ACC: best held-out accuracy so far.\n",
    "EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 3, 0, 0, 0\n",
    "\n",
    "while True:\n",
    "    lasttime = time.time()\n",
    "    if CURRENT_CHECKPOINT == EARLY_STOPPING:\n",
    "        print('break epoch:%d\\n' % (EPOCH))\n",
    "        break\n",
    "\n",
    "    train_acc, train_loss, test_acc, test_loss = 0, 0, 0, 0\n",
    "    pbar = tqdm(\n",
    "        range(0, len(train_input_ids), batch_size), desc = 'train minibatch loop'\n",
    "    )\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(train_input_ids))\n",
    "        batch_x = train_input_ids[i: index]\n",
    "        batch_masks = train_input_masks[i: index]\n",
    "        batch_segment = train_segment_ids[i: index]\n",
    "        batch_y = train_Y[i: index]\n",
    "        acc, cost, _ = sess.run(\n",
    "            [model.accuracy, model.cost, model.optimizer],\n",
    "            feed_dict = {\n",
    "                model.Y: batch_y,\n",
    "                model.X: batch_x,\n",
    "                model.segment_ids: batch_segment,\n",
    "                model.input_masks: batch_masks\n",
    "            },\n",
    "        )\n",
    "        assert not np.isnan(cost)\n",
    "        # acc and cost are per-batch means, so weight them by the batch size: the last batch\n",
    "        # can be shorter, and an unweighted sum divided by a fractional batch count is what\n",
    "        # produced accuracies above 1.0.\n",
    "        train_loss += cost * len(batch_x)\n",
    "        train_acc += acc * len(batch_x)\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "\n",
    "    # Evaluation pass over the held-out split: same feeds, but the optimizer op is not run.\n",
    "    pbar = tqdm(range(0, len(test_input_ids), batch_size), desc = 'test minibatch loop')\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(test_input_ids))\n",
    "        batch_x = test_input_ids[i: index]\n",
    "        batch_masks = test_input_masks[i: index]\n",
    "        batch_segment = test_segment_ids[i: index]\n",
    "        batch_y = test_Y[i: index]\n",
    "        acc, cost = sess.run(\n",
    "            [model.accuracy, model.cost],\n",
    "            feed_dict = {\n",
    "                model.Y: batch_y,\n",
    "                model.X: batch_x,\n",
    "                model.segment_ids: batch_segment,\n",
    "                model.input_masks: batch_masks\n",
    "            },\n",
    "        )\n",
    "        test_loss += cost * len(batch_x)\n",
    "        test_acc += acc * len(batch_x)\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "\n",
    "    # Turn the size-weighted sums into per-example means.\n",
    "    train_loss /= len(train_input_ids)\n",
    "    train_acc /= len(train_input_ids)\n",
    "    test_loss /= len(test_input_ids)\n",
    "    test_acc /= len(test_input_ids)\n",
    "\n",
    "    # A new best held-out accuracy resets the patience counter; anything else uses one unit of it.\n",
    "    if test_acc > CURRENT_ACC:\n",
    "        print(\n",
    "            'epoch: %d, pass acc: %f, current acc: %f'\n",
    "            % (EPOCH, CURRENT_ACC, test_acc)\n",
    "        )\n",
    "        CURRENT_ACC = test_acc\n",
    "        CURRENT_CHECKPOINT = 0\n",
    "    else:\n",
    "        CURRENT_CHECKPOINT += 1\n",
    "        \n",
    "    print('time taken:', time.time() - lasttime)\n",
    "    print(\n",
    "        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
    "        % (EPOCH, train_loss, train_acc, test_loss, test_acc)\n",
    "    )\n",
    "    EPOCH += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "validation minibatch loop: 100%|██████████| 34/34 [00:04<00:00,  7.83it/s]\n"
     ]
    }
   ],
   "source": [
    "# Gather the true labels and the argmax predictions for the held-out split, one batch at a time.\n",
    "real_Y, predict_Y = [], []\n",
    "\n",
    "pbar = tqdm(\n",
    "    range(0, len(test_input_ids), batch_size), desc = 'validation minibatch loop'\n",
    ")\n",
    "for i in pbar:\n",
    "    index = min(i + batch_size, len(test_input_ids))\n",
    "    batch_x = test_input_ids[i: index]\n",
    "    batch_masks = test_input_masks[i: index]\n",
    "    batch_segment = test_segment_ids[i: index]\n",
    "    batch_y = test_Y[i: index]\n",
    "    \n",
    "    feed = {\n",
    "        model.Y: batch_y,\n",
    "        model.X: batch_x,\n",
    "        model.segment_ids: batch_segment,\n",
    "        model.input_masks: batch_masks\n",
    "    }\n",
    "    batch_logits = sess.run(model.logits, feed_dict = feed)\n",
    "    \n",
    "    # Highest-scoring class per example.\n",
    "    predict_Y.extend(np.argmax(batch_logits, 1).tolist())\n",
    "    real_Y.extend(batch_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "             precision    recall  f1-score   support\n",
      "\n",
      "   negative    0.91466   0.89665   0.90557      1016\n",
      "   positive    0.89468   0.91300   0.90375       977\n",
      "\n",
      "avg / total    0.90487   0.90467   0.90468      1993\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn import metrics\n",
    "\n",
    "# Per-class precision / recall / F1, with class ids 0 and 1 shown as 'negative' and 'positive'.\n",
    "report = metrics.classification_report(\n",
    "    real_Y, predict_Y, target_names = ['negative', 'positive'],digits=5\n",
    ")\n",
    "print(report)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
